{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "H-LANn-hUlZh",
        "outputId": "229f421a-d1f5-445c-c4f2-4f10c25e13e4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Requirement already satisfied: datasets in /usr/local/lib/python3.7/dist-packages (2.6.1)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.7/dist-packages (from datasets) (3.1.0)\n",
            "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.18.0)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.3.5)\n",
            "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (4.64.1)\n",
            "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (2022.8.2)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from datasets) (1.21.6)\n",
            "Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.13)\n",
            "Requirement already satisfied: huggingface-hub<1.0.0,>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.10.1)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from datasets) (3.8.3)\n",
            "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2.23.0)\n",
            "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from datasets) (4.13.0)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from datasets) (21.3)\n",
            "Requirement already satisfied: dill<0.3.6 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.5.1)\n",
            "Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0.1)\n",
            "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (4.0.2)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (6.0.2)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (22.1.0)\n",
            "Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (0.13.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (4.1.1)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.8.1)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.3.1)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.2.0)\n",
            "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.1.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (3.8.0)\n",
            "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->datasets) (3.0.9)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (3.0.4)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (1.25.11)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2022.9.24)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2.10)\n",
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->datasets) (3.9.0)\n",
            "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2022.4)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n",
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (1.21.6)\n",
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (1.3.5)\n",
            "Requirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.7/dist-packages (from pandas) (1.21.6)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas) (2022.4)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas) (1.15.0)\n",
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.23.1)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.10.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.10.1)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.64.1)\n",
            "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.13.1)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.6)\n",
            "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.13.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.8.0)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2022.6.2)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.10.0->transformers) (4.1.1)\n",
            "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.9)\n",
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.9.0)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.25.11)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2022.9.24)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n",
            "Archive:  data_and_models.zip\n",
            "   creating: data_and_models/\n",
            "  inflating: __MACOSX/._data_and_models  \n",
            "  inflating: data_and_models/logistic_model_8.pkl  \n",
            "  inflating: __MACOSX/data_and_models/._logistic_model_8.pkl  \n",
            "  inflating: data_and_models/tfidf_44.pkl  \n",
            "  inflating: __MACOSX/data_and_models/._tfidf_44.pkl  \n",
            "  inflating: data_and_models/tfidf_8.pkl  \n",
            "  inflating: __MACOSX/data_and_models/._tfidf_8.pkl  \n",
            "  inflating: data_and_models/target_corpus.csv  \n",
            "  inflating: __MACOSX/data_and_models/._target_corpus.csv  \n",
            "  inflating: data_and_models/logistic_model_44.pkl  \n",
            "  inflating: __MACOSX/data_and_models/._logistic_model_44.pkl  \n"
          ]
        }
      ],
      "source": [
        "!pip install transformers\n",
        "!pip install datasets\n",
        "!pip install --upgrade --no-cache-dir gdown==4.5.4\n",
        "\n",
        "!gdown 18oZZ4jqRK-uF-Nz6ftRdgNjKix88hrnO\n",
        "!unzip data_and_models.zip && rm data_and_models.zip\n",
        "\n",
        "import numpy as np\n",
        "np.random.seed(11)\n",
        "import torch\n",
        "torch.manual_seed(11)\n",
        "import random\n",
        "random.seed(11)"
      ]
    },
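    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next cell fine-tunes `roberta-base` with the Hugging Face `Trainer`, once for the 42-class task and once for the 8-class task (2915/625/625 train/dev/test examples, 20 epochs, per-device batch size 16, evaluating and checkpointing every 183 steps, then reloading the best dev-accuracy checkpoint). A minimal sketch of the setup implied by its logs is below; the tokenized splits and the `compute_metrics` helper are assumptions, not the notebook's exact code:\n",
        "\n",
        "```python\n",
        "from transformers import (\n",
        "    AutoModelForSequenceClassification,\n",
        "    AutoTokenizer,\n",
        "    Trainer,\n",
        "    TrainingArguments,\n",
        ")\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"roberta-base\")\n",
        "model = AutoModelForSequenceClassification.from_pretrained(\n",
        "    \"roberta-base\", num_labels=42  # 8 for the coarse-grained run\n",
        ")\n",
        "\n",
        "args = TrainingArguments(\n",
        "    output_dir=\"./results\",        # checkpoints land in ./results/checkpoint-*\n",
        "    num_train_epochs=20,\n",
        "    per_device_train_batch_size=16,\n",
        "    per_device_eval_batch_size=64,\n",
        "    evaluation_strategy=\"epoch\",   # matches the per-epoch eval/save log lines\n",
        "    save_strategy=\"epoch\",\n",
        "    load_best_model_at_end=True,   # e.g. \"Loading best model from ./results/checkpoint-1647\"\n",
        "    metric_for_best_model=\"accuracy\",\n",
        ")\n",
        "\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=args,\n",
        "    train_dataset=train_ds,  # hypothetical tokenized Dataset splits\n",
        "    eval_dataset=dev_ds,\n",
        "    compute_metrics=compute_metrics,  # hypothetical accuracy metric fn\n",
        ")\n",
        "trainer.train()\n",
        "```"
      ]
    },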
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true,
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "-y9uB6E4gtIa",
        "outputId": "419206fc-1355-445d-ffae-87ac971b1642"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "PyTorch: setting up devices\n",
            "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "# classes 42\n",
            "2915 625 625\n",
            "# classes in train 42\n",
            "# classes in dev 36\n",
            "# classes in test 35\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "loading file vocab.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/vocab.json\n",
            "loading file merges.txt from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/merges.txt\n",
            "loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/tokenizer.json\n",
            "loading file added_tokens.json from cache at None\n",
            "loading file special_tokens_map.json from cache at None\n",
            "loading file tokenizer_config.json from cache at None\n",
            "loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/config.json\n",
            "Model config RobertaConfig {\n",
            "  \"_name_or_path\": \"roberta-base\",\n",
            "  \"architectures\": [\n",
            "    \"RobertaForMaskedLM\"\n",
            "  ],\n",
            "  \"attention_probs_dropout_prob\": 0.1,\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classifier_dropout\": null,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"hidden_act\": \"gelu\",\n",
            "  \"hidden_dropout_prob\": 0.1,\n",
            "  \"hidden_size\": 768,\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 3072,\n",
            "  \"layer_norm_eps\": 1e-05,\n",
            "  \"max_position_embeddings\": 514,\n",
            "  \"model_type\": \"roberta\",\n",
            "  \"num_attention_heads\": 12,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"position_embedding_type\": \"absolute\",\n",
            "  \"transformers_version\": \"4.23.1\",\n",
            "  \"type_vocab_size\": 1,\n",
            "  \"use_cache\": true,\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/config.json\n",
            "Model config RobertaConfig {\n",
            "  \"architectures\": [\n",
            "    \"RobertaForMaskedLM\"\n",
            "  ],\n",
            "  \"attention_probs_dropout_prob\": 0.1,\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classifier_dropout\": null,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"hidden_act\": \"gelu\",\n",
            "  \"hidden_dropout_prob\": 0.1,\n",
            "  \"hidden_size\": 768,\n",
            "  \"id2label\": {\n",
            "    \"0\": \"LABEL_0\",\n",
            "    \"1\": \"LABEL_1\",\n",
            "    \"2\": \"LABEL_2\",\n",
            "    \"3\": \"LABEL_3\",\n",
            "    \"4\": \"LABEL_4\",\n",
            "    \"5\": \"LABEL_5\",\n",
            "    \"6\": \"LABEL_6\",\n",
            "    \"7\": \"LABEL_7\",\n",
            "    \"8\": \"LABEL_8\",\n",
            "    \"9\": \"LABEL_9\",\n",
            "    \"10\": \"LABEL_10\",\n",
            "    \"11\": \"LABEL_11\",\n",
            "    \"12\": \"LABEL_12\",\n",
            "    \"13\": \"LABEL_13\",\n",
            "    \"14\": \"LABEL_14\",\n",
            "    \"15\": \"LABEL_15\",\n",
            "    \"16\": \"LABEL_16\",\n",
            "    \"17\": \"LABEL_17\",\n",
            "    \"18\": \"LABEL_18\",\n",
            "    \"19\": \"LABEL_19\",\n",
            "    \"20\": \"LABEL_20\",\n",
            "    \"21\": \"LABEL_21\",\n",
            "    \"22\": \"LABEL_22\",\n",
            "    \"23\": \"LABEL_23\",\n",
            "    \"24\": \"LABEL_24\",\n",
            "    \"25\": \"LABEL_25\",\n",
            "    \"26\": \"LABEL_26\",\n",
            "    \"27\": \"LABEL_27\",\n",
            "    \"28\": \"LABEL_28\",\n",
            "    \"29\": \"LABEL_29\",\n",
            "    \"30\": \"LABEL_30\",\n",
            "    \"31\": \"LABEL_31\",\n",
            "    \"32\": \"LABEL_32\",\n",
            "    \"33\": \"LABEL_33\",\n",
            "    \"34\": \"LABEL_34\",\n",
            "    \"35\": \"LABEL_35\",\n",
            "    \"36\": \"LABEL_36\",\n",
            "    \"37\": \"LABEL_37\",\n",
            "    \"38\": \"LABEL_38\",\n",
            "    \"39\": \"LABEL_39\",\n",
            "    \"40\": \"LABEL_40\",\n",
            "    \"41\": \"LABEL_41\"\n",
            "  },\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 3072,\n",
            "  \"label2id\": {\n",
            "    \"LABEL_0\": 0,\n",
            "    \"LABEL_1\": 1,\n",
            "    \"LABEL_10\": 10,\n",
            "    \"LABEL_11\": 11,\n",
            "    \"LABEL_12\": 12,\n",
            "    \"LABEL_13\": 13,\n",
            "    \"LABEL_14\": 14,\n",
            "    \"LABEL_15\": 15,\n",
            "    \"LABEL_16\": 16,\n",
            "    \"LABEL_17\": 17,\n",
            "    \"LABEL_18\": 18,\n",
            "    \"LABEL_19\": 19,\n",
            "    \"LABEL_2\": 2,\n",
            "    \"LABEL_20\": 20,\n",
            "    \"LABEL_21\": 21,\n",
            "    \"LABEL_22\": 22,\n",
            "    \"LABEL_23\": 23,\n",
            "    \"LABEL_24\": 24,\n",
            "    \"LABEL_25\": 25,\n",
            "    \"LABEL_26\": 26,\n",
            "    \"LABEL_27\": 27,\n",
            "    \"LABEL_28\": 28,\n",
            "    \"LABEL_29\": 29,\n",
            "    \"LABEL_3\": 3,\n",
            "    \"LABEL_30\": 30,\n",
            "    \"LABEL_31\": 31,\n",
            "    \"LABEL_32\": 32,\n",
            "    \"LABEL_33\": 33,\n",
            "    \"LABEL_34\": 34,\n",
            "    \"LABEL_35\": 35,\n",
            "    \"LABEL_36\": 36,\n",
            "    \"LABEL_37\": 37,\n",
            "    \"LABEL_38\": 38,\n",
            "    \"LABEL_39\": 39,\n",
            "    \"LABEL_4\": 4,\n",
            "    \"LABEL_40\": 40,\n",
            "    \"LABEL_41\": 41,\n",
            "    \"LABEL_5\": 5,\n",
            "    \"LABEL_6\": 6,\n",
            "    \"LABEL_7\": 7,\n",
            "    \"LABEL_8\": 8,\n",
            "    \"LABEL_9\": 9\n",
            "  },\n",
            "  \"layer_norm_eps\": 1e-05,\n",
            "  \"max_position_embeddings\": 514,\n",
            "  \"model_type\": \"roberta\",\n",
            "  \"num_attention_heads\": 12,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"position_embedding_type\": \"absolute\",\n",
            "  \"transformers_version\": \"4.23.1\",\n",
            "  \"type_vocab_size\": 1,\n",
            "  \"use_cache\": true,\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "loading weights file pytorch_model.bin from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/pytorch_model.bin\n",
            "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.dense.bias', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']\n",
            "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/config.json\n",
            "Model config RobertaConfig {\n",
            "  \"architectures\": [\n",
            "    \"RobertaForMaskedLM\"\n",
            "  ],\n",
            "  \"attention_probs_dropout_prob\": 0.1,\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classifier_dropout\": null,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"hidden_act\": \"gelu\",\n",
            "  \"hidden_dropout_prob\": 0.1,\n",
            "  \"hidden_size\": 768,\n",
            "  \"id2label\": {\n",
            "    \"0\": \"LABEL_0\",\n",
            "    \"1\": \"LABEL_1\",\n",
            "    \"2\": \"LABEL_2\",\n",
            "    \"3\": \"LABEL_3\",\n",
            "    \"4\": \"LABEL_4\",\n",
            "    \"5\": \"LABEL_5\",\n",
            "    \"6\": \"LABEL_6\",\n",
            "    \"7\": \"LABEL_7\",\n",
            "    \"8\": \"LABEL_8\",\n",
            "    \"9\": \"LABEL_9\",\n",
            "    \"10\": \"LABEL_10\",\n",
            "    \"11\": \"LABEL_11\",\n",
            "    \"12\": \"LABEL_12\",\n",
            "    \"13\": \"LABEL_13\",\n",
            "    \"14\": \"LABEL_14\",\n",
            "    \"15\": \"LABEL_15\",\n",
            "    \"16\": \"LABEL_16\",\n",
            "    \"17\": \"LABEL_17\",\n",
            "    \"18\": \"LABEL_18\",\n",
            "    \"19\": \"LABEL_19\",\n",
            "    \"20\": \"LABEL_20\",\n",
            "    \"21\": \"LABEL_21\",\n",
            "    \"22\": \"LABEL_22\",\n",
            "    \"23\": \"LABEL_23\",\n",
            "    \"24\": \"LABEL_24\",\n",
            "    \"25\": \"LABEL_25\",\n",
            "    \"26\": \"LABEL_26\",\n",
            "    \"27\": \"LABEL_27\",\n",
            "    \"28\": \"LABEL_28\",\n",
            "    \"29\": \"LABEL_29\",\n",
            "    \"30\": \"LABEL_30\",\n",
            "    \"31\": \"LABEL_31\",\n",
            "    \"32\": \"LABEL_32\",\n",
            "    \"33\": \"LABEL_33\",\n",
            "    \"34\": \"LABEL_34\",\n",
            "    \"35\": \"LABEL_35\",\n",
            "    \"36\": \"LABEL_36\",\n",
            "    \"37\": \"LABEL_37\",\n",
            "    \"38\": \"LABEL_38\",\n",
            "    \"39\": \"LABEL_39\",\n",
            "    \"40\": \"LABEL_40\",\n",
            "    \"41\": \"LABEL_41\"\n",
            "  },\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 3072,\n",
            "  \"label2id\": {\n",
            "    \"LABEL_0\": 0,\n",
            "    \"LABEL_1\": 1,\n",
            "    \"LABEL_10\": 10,\n",
            "    \"LABEL_11\": 11,\n",
            "    \"LABEL_12\": 12,\n",
            "    \"LABEL_13\": 13,\n",
            "    \"LABEL_14\": 14,\n",
            "    \"LABEL_15\": 15,\n",
            "    \"LABEL_16\": 16,\n",
            "    \"LABEL_17\": 17,\n",
            "    \"LABEL_18\": 18,\n",
            "    \"LABEL_19\": 19,\n",
            "    \"LABEL_2\": 2,\n",
            "    \"LABEL_20\": 20,\n",
            "    \"LABEL_21\": 21,\n",
            "    \"LABEL_22\": 22,\n",
            "    \"LABEL_23\": 23,\n",
            "    \"LABEL_24\": 24,\n",
            "    \"LABEL_25\": 25,\n",
            "    \"LABEL_26\": 26,\n",
            "    \"LABEL_27\": 27,\n",
            "    \"LABEL_28\": 28,\n",
            "    \"LABEL_29\": 29,\n",
            "    \"LABEL_3\": 3,\n",
            "    \"LABEL_30\": 30,\n",
            "    \"LABEL_31\": 31,\n",
            "    \"LABEL_32\": 32,\n",
            "    \"LABEL_33\": 33,\n",
            "    \"LABEL_34\": 34,\n",
            "    \"LABEL_35\": 35,\n",
            "    \"LABEL_36\": 36,\n",
            "    \"LABEL_37\": 37,\n",
            "    \"LABEL_38\": 38,\n",
            "    \"LABEL_39\": 39,\n",
            "    \"LABEL_4\": 4,\n",
            "    \"LABEL_40\": 40,\n",
            "    \"LABEL_41\": 41,\n",
            "    \"LABEL_5\": 5,\n",
            "    \"LABEL_6\": 6,\n",
            "    \"LABEL_7\": 7,\n",
            "    \"LABEL_8\": 8,\n",
            "    \"LABEL_9\": 9\n",
            "  },\n",
            "  \"layer_norm_eps\": 1e-05,\n",
            "  \"max_position_embeddings\": 514,\n",
            "  \"model_type\": \"roberta\",\n",
            "  \"num_attention_heads\": 12,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"position_embedding_type\": \"absolute\",\n",
            "  \"transformers_version\": \"4.23.1\",\n",
            "  \"type_vocab_size\": 1,\n",
            "  \"use_cache\": true,\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "loading weights file pytorch_model.bin from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/pytorch_model.bin\n",
            "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.dense.bias', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']\n",
            "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
            "  FutureWarning,\n",
            "***** Running training *****\n",
            "  Num examples = 2915\n",
            "  Num Epochs = 20\n",
            "  Instantaneous batch size per device = 16\n",
            "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
            "  Gradient Accumulation steps = 1\n",
            "  Total optimization steps = 3660\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='3660' max='3660' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [3660/3660 25:48, Epoch 20/20]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Accuracy</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>2.430300</td>\n",
              "      <td>2.288778</td>\n",
              "      <td>0.417600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>1.917100</td>\n",
              "      <td>1.934722</td>\n",
              "      <td>0.516800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>1.582500</td>\n",
              "      <td>1.823273</td>\n",
              "      <td>0.542400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>1.362400</td>\n",
              "      <td>1.823743</td>\n",
              "      <td>0.548800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>0.972300</td>\n",
              "      <td>1.889241</td>\n",
              "      <td>0.523200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>0.848000</td>\n",
              "      <td>1.888951</td>\n",
              "      <td>0.534400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>0.573900</td>\n",
              "      <td>2.010903</td>\n",
              "      <td>0.532800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>0.409200</td>\n",
              "      <td>1.989290</td>\n",
              "      <td>0.553600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>0.317400</td>\n",
              "      <td>2.029555</td>\n",
              "      <td>0.563200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>0.230900</td>\n",
              "      <td>2.128701</td>\n",
              "      <td>0.532800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>11</td>\n",
              "      <td>0.196900</td>\n",
              "      <td>2.317828</td>\n",
              "      <td>0.539200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>12</td>\n",
              "      <td>0.136000</td>\n",
              "      <td>2.398464</td>\n",
              "      <td>0.540800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13</td>\n",
              "      <td>0.049600</td>\n",
              "      <td>2.501200</td>\n",
              "      <td>0.545600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14</td>\n",
              "      <td>0.136700</td>\n",
              "      <td>2.545634</td>\n",
              "      <td>0.552000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>0.057400</td>\n",
              "      <td>2.664694</td>\n",
              "      <td>0.544000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>16</td>\n",
              "      <td>0.082100</td>\n",
              "      <td>2.783216</td>\n",
              "      <td>0.531200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>17</td>\n",
              "      <td>0.014700</td>\n",
              "      <td>2.788378</td>\n",
              "      <td>0.548800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>18</td>\n",
              "      <td>0.063300</td>\n",
              "      <td>2.824948</td>\n",
              "      <td>0.547200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>19</td>\n",
              "      <td>0.027000</td>\n",
              "      <td>2.827253</td>\n",
              "      <td>0.547200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>0.009300</td>\n",
              "      <td>2.841008</td>\n",
              "      <td>0.547200</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-183\n",
            "Configuration saved in ./results/checkpoint-183/config.json\n",
            "Model weights saved in ./results/checkpoint-183/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-366\n",
            "Configuration saved in ./results/checkpoint-366/config.json\n",
            "Model weights saved in ./results/checkpoint-366/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-549\n",
            "Configuration saved in ./results/checkpoint-549/config.json\n",
            "Model weights saved in ./results/checkpoint-549/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-732\n",
            "Configuration saved in ./results/checkpoint-732/config.json\n",
            "Model weights saved in ./results/checkpoint-732/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-915\n",
            "Configuration saved in ./results/checkpoint-915/config.json\n",
            "Model weights saved in ./results/checkpoint-915/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-1098\n",
            "Configuration saved in ./results/checkpoint-1098/config.json\n",
            "Model weights saved in ./results/checkpoint-1098/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-1281\n",
            "Configuration saved in ./results/checkpoint-1281/config.json\n",
            "Model weights saved in ./results/checkpoint-1281/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-1464\n",
            "Configuration saved in ./results/checkpoint-1464/config.json\n",
            "Model weights saved in ./results/checkpoint-1464/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-1647\n",
            "Configuration saved in ./results/checkpoint-1647/config.json\n",
            "Model weights saved in ./results/checkpoint-1647/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-1830\n",
            "Configuration saved in ./results/checkpoint-1830/config.json\n",
            "Model weights saved in ./results/checkpoint-1830/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2013\n",
            "Configuration saved in ./results/checkpoint-2013/config.json\n",
            "Model weights saved in ./results/checkpoint-2013/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2196\n",
            "Configuration saved in ./results/checkpoint-2196/config.json\n",
            "Model weights saved in ./results/checkpoint-2196/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2379\n",
            "Configuration saved in ./results/checkpoint-2379/config.json\n",
            "Model weights saved in ./results/checkpoint-2379/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2562\n",
            "Configuration saved in ./results/checkpoint-2562/config.json\n",
            "Model weights saved in ./results/checkpoint-2562/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2745\n",
            "Configuration saved in ./results/checkpoint-2745/config.json\n",
            "Model weights saved in ./results/checkpoint-2745/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2928\n",
            "Configuration saved in ./results/checkpoint-2928/config.json\n",
            "Model weights saved in ./results/checkpoint-2928/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-3111\n",
            "Configuration saved in ./results/checkpoint-3111/config.json\n",
            "Model weights saved in ./results/checkpoint-3111/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-3294\n",
            "Configuration saved in ./results/checkpoint-3294/config.json\n",
            "Model weights saved in ./results/checkpoint-3294/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-3477\n",
            "Configuration saved in ./results/checkpoint-3477/config.json\n",
            "Model weights saved in ./results/checkpoint-3477/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-3660\n",
            "Configuration saved in ./results/checkpoint-3660/config.json\n",
            "Model weights saved in ./results/checkpoint-3660/pytorch_model.bin\n",
            "\n",
            "\n",
            "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
            "\n",
            "\n",
            "Loading best model from ./results/checkpoint-1647 (score: 0.5632).\n",
            "***** Running Prediction *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "# classes 8\n",
            "2915 625 625\n",
            "# classes in train 8\n",
            "# classes in dev 8\n",
            "# classes in test 8\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "loading file vocab.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/vocab.json\n",
            "loading file merges.txt from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/merges.txt\n",
            "loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/tokenizer.json\n",
            "loading file added_tokens.json from cache at None\n",
            "loading file special_tokens_map.json from cache at None\n",
            "loading file tokenizer_config.json from cache at None\n",
            "loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/config.json\n",
            "Model config RobertaConfig {\n",
            "  \"_name_or_path\": \"roberta-base\",\n",
            "  \"architectures\": [\n",
            "    \"RobertaForMaskedLM\"\n",
            "  ],\n",
            "  \"attention_probs_dropout_prob\": 0.1,\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classifier_dropout\": null,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"hidden_act\": \"gelu\",\n",
            "  \"hidden_dropout_prob\": 0.1,\n",
            "  \"hidden_size\": 768,\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 3072,\n",
            "  \"layer_norm_eps\": 1e-05,\n",
            "  \"max_position_embeddings\": 514,\n",
            "  \"model_type\": \"roberta\",\n",
            "  \"num_attention_heads\": 12,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"position_embedding_type\": \"absolute\",\n",
            "  \"transformers_version\": \"4.23.1\",\n",
            "  \"type_vocab_size\": 1,\n",
            "  \"use_cache\": true,\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/config.json\n",
            "Model config RobertaConfig {\n",
            "  \"architectures\": [\n",
            "    \"RobertaForMaskedLM\"\n",
            "  ],\n",
            "  \"attention_probs_dropout_prob\": 0.1,\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classifier_dropout\": null,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"hidden_act\": \"gelu\",\n",
            "  \"hidden_dropout_prob\": 0.1,\n",
            "  \"hidden_size\": 768,\n",
            "  \"id2label\": {\n",
            "    \"0\": \"LABEL_0\",\n",
            "    \"1\": \"LABEL_1\",\n",
            "    \"2\": \"LABEL_2\",\n",
            "    \"3\": \"LABEL_3\",\n",
            "    \"4\": \"LABEL_4\",\n",
            "    \"5\": \"LABEL_5\",\n",
            "    \"6\": \"LABEL_6\",\n",
            "    \"7\": \"LABEL_7\"\n",
            "  },\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 3072,\n",
            "  \"label2id\": {\n",
            "    \"LABEL_0\": 0,\n",
            "    \"LABEL_1\": 1,\n",
            "    \"LABEL_2\": 2,\n",
            "    \"LABEL_3\": 3,\n",
            "    \"LABEL_4\": 4,\n",
            "    \"LABEL_5\": 5,\n",
            "    \"LABEL_6\": 6,\n",
            "    \"LABEL_7\": 7\n",
            "  },\n",
            "  \"layer_norm_eps\": 1e-05,\n",
            "  \"max_position_embeddings\": 514,\n",
            "  \"model_type\": \"roberta\",\n",
            "  \"num_attention_heads\": 12,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"position_embedding_type\": \"absolute\",\n",
            "  \"transformers_version\": \"4.23.1\",\n",
            "  \"type_vocab_size\": 1,\n",
            "  \"use_cache\": true,\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "loading weights file pytorch_model.bin from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/pytorch_model.bin\n",
            "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.dense.bias', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']\n",
            "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/config.json\n",
            "Model config RobertaConfig {\n",
            "  \"architectures\": [\n",
            "    \"RobertaForMaskedLM\"\n",
            "  ],\n",
            "  \"attention_probs_dropout_prob\": 0.1,\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classifier_dropout\": null,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"hidden_act\": \"gelu\",\n",
            "  \"hidden_dropout_prob\": 0.1,\n",
            "  \"hidden_size\": 768,\n",
            "  \"id2label\": {\n",
            "    \"0\": \"LABEL_0\",\n",
            "    \"1\": \"LABEL_1\",\n",
            "    \"2\": \"LABEL_2\",\n",
            "    \"3\": \"LABEL_3\",\n",
            "    \"4\": \"LABEL_4\",\n",
            "    \"5\": \"LABEL_5\",\n",
            "    \"6\": \"LABEL_6\",\n",
            "    \"7\": \"LABEL_7\"\n",
            "  },\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 3072,\n",
            "  \"label2id\": {\n",
            "    \"LABEL_0\": 0,\n",
            "    \"LABEL_1\": 1,\n",
            "    \"LABEL_2\": 2,\n",
            "    \"LABEL_3\": 3,\n",
            "    \"LABEL_4\": 4,\n",
            "    \"LABEL_5\": 5,\n",
            "    \"LABEL_6\": 6,\n",
            "    \"LABEL_7\": 7\n",
            "  },\n",
            "  \"layer_norm_eps\": 1e-05,\n",
            "  \"max_position_embeddings\": 514,\n",
            "  \"model_type\": \"roberta\",\n",
            "  \"num_attention_heads\": 12,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"position_embedding_type\": \"absolute\",\n",
            "  \"transformers_version\": \"4.23.1\",\n",
            "  \"type_vocab_size\": 1,\n",
            "  \"use_cache\": true,\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "loading weights file pytorch_model.bin from cache at /root/.cache/huggingface/hub/models--roberta-base/snapshots/ff46155979338ff8063cdad90908b498ab91b181/pytorch_model.bin\n",
            "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.dense.bias', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']\n",
            "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
            "  FutureWarning,\n",
            "***** Running training *****\n",
            "  Num examples = 2915\n",
            "  Num Epochs = 20\n",
            "  Instantaneous batch size per device = 16\n",
            "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
            "  Gradient Accumulation steps = 1\n",
            "  Total optimization steps = 3660\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='3660' max='3660' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [3660/3660 26:01, Epoch 20/20]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Accuracy</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.319900</td>\n",
              "      <td>1.205588</td>\n",
              "      <td>0.577600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.978900</td>\n",
              "      <td>1.116255</td>\n",
              "      <td>0.619200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.816300</td>\n",
              "      <td>1.130859</td>\n",
              "      <td>0.633600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.821600</td>\n",
              "      <td>1.280719</td>\n",
              "      <td>0.632000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>0.496800</td>\n",
              "      <td>1.309494</td>\n",
              "      <td>0.628800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>0.146600</td>\n",
              "      <td>1.554289</td>\n",
              "      <td>0.609600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>0.165500</td>\n",
              "      <td>1.672504</td>\n",
              "      <td>0.625600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>0.178600</td>\n",
              "      <td>1.918652</td>\n",
              "      <td>0.628800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>0.038300</td>\n",
              "      <td>2.086765</td>\n",
              "      <td>0.638400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>0.116000</td>\n",
              "      <td>2.294300</td>\n",
              "      <td>0.625600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>11</td>\n",
              "      <td>0.003500</td>\n",
              "      <td>2.416271</td>\n",
              "      <td>0.636800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>12</td>\n",
              "      <td>0.001500</td>\n",
              "      <td>2.489655</td>\n",
              "      <td>0.638400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13</td>\n",
              "      <td>0.004700</td>\n",
              "      <td>2.526492</td>\n",
              "      <td>0.652800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14</td>\n",
              "      <td>0.000900</td>\n",
              "      <td>2.617174</td>\n",
              "      <td>0.627200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>0.019600</td>\n",
              "      <td>2.713682</td>\n",
              "      <td>0.630400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>16</td>\n",
              "      <td>0.000700</td>\n",
              "      <td>2.661806</td>\n",
              "      <td>0.638400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>17</td>\n",
              "      <td>0.003200</td>\n",
              "      <td>2.688572</td>\n",
              "      <td>0.652800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>18</td>\n",
              "      <td>0.001600</td>\n",
              "      <td>2.701345</td>\n",
              "      <td>0.652800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>19</td>\n",
              "      <td>0.000500</td>\n",
              "      <td>2.690373</td>\n",
              "      <td>0.654400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>0.000500</td>\n",
              "      <td>2.695814</td>\n",
              "      <td>0.657600</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-183\n",
            "Configuration saved in ./results/checkpoint-183/config.json\n",
            "Model weights saved in ./results/checkpoint-183/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-366\n",
            "Configuration saved in ./results/checkpoint-366/config.json\n",
            "Model weights saved in ./results/checkpoint-366/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-549\n",
            "Configuration saved in ./results/checkpoint-549/config.json\n",
            "Model weights saved in ./results/checkpoint-549/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-732\n",
            "Configuration saved in ./results/checkpoint-732/config.json\n",
            "Model weights saved in ./results/checkpoint-732/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-915\n",
            "Configuration saved in ./results/checkpoint-915/config.json\n",
            "Model weights saved in ./results/checkpoint-915/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-1098\n",
            "Configuration saved in ./results/checkpoint-1098/config.json\n",
            "Model weights saved in ./results/checkpoint-1098/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-1281\n",
            "Configuration saved in ./results/checkpoint-1281/config.json\n",
            "Model weights saved in ./results/checkpoint-1281/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-1464\n",
            "Configuration saved in ./results/checkpoint-1464/config.json\n",
            "Model weights saved in ./results/checkpoint-1464/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-1647\n",
            "Configuration saved in ./results/checkpoint-1647/config.json\n",
            "Model weights saved in ./results/checkpoint-1647/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-1830\n",
            "Configuration saved in ./results/checkpoint-1830/config.json\n",
            "Model weights saved in ./results/checkpoint-1830/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2013\n",
            "Configuration saved in ./results/checkpoint-2013/config.json\n",
            "Model weights saved in ./results/checkpoint-2013/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2196\n",
            "Configuration saved in ./results/checkpoint-2196/config.json\n",
            "Model weights saved in ./results/checkpoint-2196/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2379\n",
            "Configuration saved in ./results/checkpoint-2379/config.json\n",
            "Model weights saved in ./results/checkpoint-2379/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2562\n",
            "Configuration saved in ./results/checkpoint-2562/config.json\n",
            "Model weights saved in ./results/checkpoint-2562/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2745\n",
            "Configuration saved in ./results/checkpoint-2745/config.json\n",
            "Model weights saved in ./results/checkpoint-2745/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-2928\n",
            "Configuration saved in ./results/checkpoint-2928/config.json\n",
            "Model weights saved in ./results/checkpoint-2928/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-3111\n",
            "Configuration saved in ./results/checkpoint-3111/config.json\n",
            "Model weights saved in ./results/checkpoint-3111/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-3294\n",
            "Configuration saved in ./results/checkpoint-3294/config.json\n",
            "Model weights saved in ./results/checkpoint-3294/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-3477\n",
            "Configuration saved in ./results/checkpoint-3477/config.json\n",
            "Model weights saved in ./results/checkpoint-3477/pytorch_model.bin\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n",
            "Saving model checkpoint to ./results/checkpoint-3660\n",
            "Configuration saved in ./results/checkpoint-3660/config.json\n",
            "Model weights saved in ./results/checkpoint-3660/pytorch_model.bin\n",
            "\n",
            "\n",
            "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
            "\n",
            "\n",
            "Loading best model from ./results/checkpoint-3660 (score: 0.6576).\n",
            "***** Running Prediction *****\n",
            "  Num examples = 625\n",
            "  Batch size = 64\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/sklearn/base.py:338: UserWarning: Trying to unpickle estimator LogisticRegression from version 0.24.1 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
            "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n",
            "  UserWarning,\n",
            "/usr/local/lib/python3.7/dist-packages/sklearn/base.py:338: UserWarning: Trying to unpickle estimator TfidfTransformer from version 0.24.1 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
            "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n",
            "  UserWarning,\n",
            "/usr/local/lib/python3.7/dist-packages/sklearn/base.py:338: UserWarning: Trying to unpickle estimator TfidfVectorizer from version 0.24.1 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
            "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n",
            "  UserWarning,\n"
          ]
        }
      ],
      "source": [
        "from collections import defaultdict\n",
        "import csv\n",
        "import pickle\n",
        "import time\n",
        "\n",
        "from datasets import load_metric\n",
        "import pandas as pd\n",
        "from sklearn.model_selection import train_test_split\n",
        "from transformers import RobertaForSequenceClassification, RobertaTokenizerFast, Trainer, TrainingArguments\n",
        "\n",
        "def compute_metrics(eval_preds):\n",
        "    metric = load_metric(\"accuracy\")\n",
        "    logits, labels = eval_preds\n",
        "    predictions = np.argmax(logits, axis=-1)\n",
        "    return metric.compute(predictions=predictions, references=labels)\n",
        "\n",
        "def multi_class_top_one_accuracy(predictions, labels, class_i):\n",
        "  \"\"\"\n",
        "  For each class, calculate the top 1 accuracy. \n",
        "  \"\"\"\n",
        "  assert len(predictions) == len(labels)\n",
        "  total = 0\n",
        "  correct = 0\n",
        "  for i in range(len(predictions)):\n",
        "    if labels[i] != class_i:\n",
        "      continue\n",
        "    total += 1\n",
        "    prediction = []\n",
        "    for j, k in enumerate(predictions[i]):\n",
        "      prediction.append([j, k]) # k is the value\n",
        "    prediction.sort(key = lambda x: -x[1])\n",
        "    if prediction[0][0] == labels[i]:\n",
        "      correct += 1\n",
        "  ans = str(round(correct/total, 3))\n",
        "  if len(ans) < 5:\n",
        "    ans += \"0\" * (5-len(ans))\n",
        "  return ans\n",
        "\n",
        "class PSCDataset(torch.utils.data.Dataset):\n",
        "    def __init__(self, encodings, labels):\n",
        "        self.encodings = encodings\n",
        "        self.labels = labels\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
        "        item['labels'] = torch.tensor(self.labels[idx])\n",
        "        return item\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.labels)\n",
        "\n",
        "start = time.time()\n",
        "mlength = 512\n",
        "directory = \"./data_and_models/\"\n",
        "\n",
        "training_args = TrainingArguments(\n",
        "    output_dir='./results',          # output directory\n",
        "    num_train_epochs=20,             # total number of training epochs\n",
        "    per_device_train_batch_size=16,  # batch size per device during training\n",
        "    per_device_eval_batch_size=64,   # batch size for evaluation\n",
        "    warmup_steps=0,                  # number of warmup steps for learning rate scheduler\n",
        "    weight_decay=0.01,               # strength of weight decay\n",
        "    logging_dir='./logs',            # directory for storing logs\n",
        "    logging_steps=10,\n",
        "    learning_rate = 2e-5,\n",
        "    save_strategy= \"epoch\",\n",
        "    evaluation_strategy=\"epoch\",\n",
        "    load_best_model_at_end= True,\n",
        "    metric_for_best_model=\"accuracy\",\n",
        "    seed = 11,\n",
        ")\n",
        "\n",
        "tasks = {\n",
        "    \"44\": {\n",
        "        \"number_of_labels\": 42,\n",
        "         \"label_column\": 1,\n",
        "    },\n",
        "    \"8\": {\n",
        "        \"number_of_labels\": 8,\n",
        "        \"label_column\": 2,\n",
        "    }\n",
        "}\n",
        "\n",
        "def compute_task(task):\n",
        "  index = -1\n",
        "  classes = {}\n",
        "  texts = []\n",
        "  labels = []\n",
        "  lm_reverse_mapper = {}\n",
        "\n",
        "\n",
        "  with open(directory + \"target_corpus.csv\") as doc:\n",
        "    reader = csv.reader(doc)\n",
        "    next(reader)\n",
        "    for row in reader:\n",
        "      topic = row[tasks[task][\"label_column\"]]\n",
        "      if topic not in classes:\n",
        "        index += 1\n",
        "        classes[topic] = index\n",
        "        lm_reverse_mapper[index] = topic.capitalize()\n",
        "      labels.append(classes[topic])\n",
        "      texts.append(row[0])\n",
        "\n",
        "  print(\"# classes\", len(classes))\n",
        "  X_train, X_test, y_train, y_test = train_test_split(texts, labels, test_size=625, random_state=11)\n",
        "  X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train, test_size=625, random_state=11)\n",
        "  print(len(X_train), len(X_dev), len(X_test))\n",
        "  print(\"# classes in train\", len(set(y_train)))\n",
        "  print(\"# classes in dev\", len(set(y_dev)))\n",
        "  print(\"# classes in test\", len(set(y_test)))\n",
        "\n",
        "  tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')\n",
        "  train_encodings = tokenizer(X_train, truncation=True, padding=True, max_length=mlength)\n",
        "  dev_encodings = tokenizer(X_dev, truncation=True, padding=True, max_length = mlength)\n",
        "  test_encodings = tokenizer(X_test, truncation=True, padding=True, max_length= mlength)\n",
        "\n",
        "  train_dataset = PSCDataset(train_encodings, y_train)\n",
        "  dev_dataset = PSCDataset(dev_encodings, y_dev)\n",
        "  test_dataset = PSCDataset(test_encodings, y_test)\n",
        "\n",
        "  def model_init():\n",
        "    return RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=tasks[task][\"number_of_labels\"])\n",
        "\n",
        "  trainer = Trainer(\n",
        "      model_init=model_init,               # the instantiated 🤗 Transformers model to be trained\n",
        "      args=training_args,                  # training arguments, defined above\n",
        "      train_dataset=train_dataset,         # training dataset\n",
        "      eval_dataset=dev_dataset,            # evaluation dataset\n",
        "      compute_metrics=compute_metrics,     # compute_metrics\n",
        "      )\n",
        "  trainer.train()\n",
        "\n",
        "  predictions = trainer.predict(test_dataset)\n",
        "\n",
        "\n",
        "  with open(directory + \"logistic_model_\" + task + \".pkl\", \"rb\") as doc:\n",
        "    model = pickle.load(doc)\n",
        "  with open(directory + \"tfidf_\" + task + \".pkl\", \"rb\") as doc:\n",
        "    tokenizer = pickle.load(doc)\n",
        "\n",
        "  class_mapper = {}\n",
        "  class_reverse_mapper = {}\n",
        "  for i, topic in enumerate(model.classes_):\n",
        "    class_mapper[topic.replace(\" \", \".\").replace(\"-\", \".\")] = i\n",
        "    class_reverse_mapper[i] = topic\n",
        "\n",
        "  df = pd.read_csv(directory + \"target_corpus.csv\")\n",
        "  df = df[df[\"text\"].isin(X_test)]\n",
        "  X = df['text']\n",
        "  Y = list(df[\"topic_\"+task].transform(lambda x: class_mapper[x]))\n",
        "\n",
        "  Xtfidf = tokenizer.transform(X)\n",
        "\n",
        "  preds = model.predict(Xtfidf)\n",
        "  preds = [class_mapper[topic.replace(\" \", \".\").replace(\"-\", \".\")] for topic in preds]\n",
        "  policy_probs = model.predict_proba(Xtfidf)\n",
        "\n",
        "  from collections import Counter\n",
        "  counter = Counter(Y)\n",
        "  results = []\n",
        "  for class_i, count in counter.items():\n",
        "    result = [class_reverse_mapper[class_i].capitalize(), count]\n",
        "    result.append(multi_class_top_one_accuracy(policy_probs, Y, class_i))\n",
        "    #sample result:\n",
        "    #['Political authority', 140, '0.550']\n",
        "    #['Welfare state expansion', 49, '0.694']\n",
        "    results.append(result)\n",
        "\n",
        "  results.sort(key = lambda result: [-result[1], result[0]])\n",
        "\n",
        "  for index, topic in lm_reverse_mapper.items():\n",
        "    topic = topic.replace(\".\", \" \")\n",
        "    if \"demographic\" in topic:\n",
        "      topic = \"Non-economic demographic groups\"\n",
        "    lm_reverse_mapper[index]=topic\n",
        "\n",
        "  lm_per_class_predictions = defaultdict(str)\n",
        "  counter = Counter(y_test)\n",
        "  for class_i, count in counter.items():\n",
        "    lm_per_class_predictions[lm_reverse_mapper[class_i]] = multi_class_top_one_accuracy(predictions.predictions, y_test, class_i)\n",
        "  outputs = []\n",
        "  for result in results:\n",
        "    result += [lm_per_class_predictions[result[0]]]\n",
        "    if float(result[-1]) > float(result[-2]):\n",
        "      result[-1] = \"\\\\textbf{\" + result[-1] + \"}\" \n",
        "    elif float(result[-1]) < float(result[-2]):\n",
        "      result[-2] = \"\\\\textbf{\" + result[-2] + \"}\" \n",
        "    str_result = [str(i) for i in result]\n",
        "    outputs.append(\"& \" +\" & \".join(str_result)+\"\\\\\\\\\")\n",
        "  return outputs\n",
        "\n",
        "results = {}\n",
        "for task in tasks:\n",
        "  results[task] = compute_task(task)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "mb8THt2M_5G3",
        "outputId": "9659a22b-5d51-4c89-b995-715941af359a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "==== 44-Topic Classification ====\n",
            "& Political authority & 140 & 0.550 & \\textbf{0.657}\\\\\n",
            "& Welfare state expansion & 49 & 0.694 & \\textbf{0.714}\\\\\n",
            "& Democracy & 44 & 0.318 & \\textbf{0.341}\\\\\n",
            "& No topic & 32 & 0.000 & \\textbf{0.438}\\\\\n",
            "& Labour groups & 31 & 0.387 & \\textbf{0.484}\\\\\n",
            "& Education & 26 & \\textbf{0.885} & 0.846\\\\\n",
            "& Constitutionalism & 24 & 0.000 & \\textbf{0.458}\\\\\n",
            "& Economic orthodoxy & 21 & 0.238 & \\textbf{0.571}\\\\\n",
            "& Governmental and administrative efficiency & 21 & 0.238 & 0.238\\\\\n",
            "& Technology and infrastructure & 21 & 0.333 & \\textbf{0.524}\\\\\n",
            "& Law and order & 20 & 0.650 & \\textbf{0.700}\\\\\n",
            "& Multiculturalism & 19 & 0.632 & \\textbf{0.842}\\\\\n",
            "& Equality & 18 & \\textbf{0.389} & 0.278\\\\\n",
            "& Free market economy & 15 & 0.000 & \\textbf{0.267}\\\\\n",
            "& Economic growth & 13 & 0.615 & \\textbf{0.769}\\\\\n",
            "& Freedom and human rights & 13 & 0.000 & \\textbf{0.231}\\\\\n",
            "& Market regulation & 12 & 0.167 & \\textbf{0.333}\\\\\n",
            "& Traditional morality & 12 & 0.250 & \\textbf{0.333}\\\\\n",
            "& Military & 11 & 0.727 & \\textbf{0.909}\\\\\n",
            "& National way of life & 10 & 0.300 & 0.300\\\\\n",
            "& Political corruption & 10 & 0.100 & \\textbf{0.200}\\\\\n",
            "& Protectionism & 10 & 0.200 & \\textbf{0.600}\\\\\n",
            "& Centralization & 9 & 0.111 & \\textbf{0.222}\\\\\n",
            "& Environmental protection & 9 & 0.667 & \\textbf{1.000}\\\\\n",
            "& Agriculture and farmers & 7 & \\textbf{0.714} & 0.571\\\\\n",
            "& Incentives & 7 & 0.571 & 0.571\\\\\n",
            "& Civic mindedness & 6 & 0.000 & 0.000\\\\\n",
            "& Nationalisation & 5 & \\textbf{0.400} & 0.200\\\\\n",
            "& Culture & 3 & 0.000 & \\textbf{0.667}\\\\\n",
            "& Internationalism & 2 & 0.000 & \\textbf{0.500}\\\\\n",
            "& Controlled economy & 1 & 0.000 & 0.000\\\\\n",
            "& Middle class and professional groups & 1 & 0.000 & 0.000\\\\\n",
            "& Non-economic demographic groups & 1 & 1.000 & 1.000\\\\\n",
            "& Peace & 1 & 0.000 & 0.000\\\\\n",
            "& Underprivileged minority groups & 1 & \\textbf{1.000} & 0.000\\\\\n",
            "==== 8-Topic Classification ====\n",
            "& Political system & 180 & 0.556 & \\textbf{0.622}\\\\\n",
            "& Economy & 105 & 0.600 & \\textbf{0.705}\\\\\n",
            "& Welfare and quality of life & 105 & 0.667 & \\textbf{0.810}\\\\\n",
            "& Freedom and democracy & 81 & 0.284 & \\textbf{0.556}\\\\\n",
            "& Fabric of society & 67 & \\textbf{0.582} & 0.522\\\\\n",
            "& Social groups & 41 & 0.415 & \\textbf{0.537}\\\\\n",
            "& No topic & 32 & 0.000 & \\textbf{0.344}\\\\\n",
            "& External relations & 14 & 0.571 & \\textbf{0.857}\\\\\n"
          ]
        }
      ],
      "source": [
        "for task in tasks:\n",
        "  print(\"==== \" + task + \"-Topic Classification ====\")\n",
        "  for result in results[task]:\n",
        "    print(result)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "0xqaBF7uZmZQ",
        "outputId": "1a24b909-7a0f-4f20-d777-6b7be7583b1b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The program took 52.0 minutes in total.\n"
          ]
        }
      ],
      "source": [
        "end = time.time()\n",
        "print(f\"The program took {(end - start) // 60} minutes in total.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "asJkHQIfZnGk"
      },
      "outputs": [],
      "source": [
        "from google.colab import runtime\n",
        "runtime.unassign()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "gpuClass": "premium",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}